import numpy as np
from sklearn.linear_model import LogisticRegression

from learner import Learner


class C2UpUCB(Learner):
    """Budgeted treatment-assignment learner using logistic models and UCB bonuses.

    One ``LogisticRegression`` (no intercept) is kept per treatment, together
    with a regularised Gram matrix of the contexts observed under that
    treatment.  At each round the learner scores every individual by
    ``min(1, max_k UCB_k(x)) - baseline(x)`` and gives the best-scoring
    treatment to the ``budget`` individuals with the largest score; everyone
    else receives arm ``0`` (no treatment).

    Arms are numbered ``1..n_treatments``; arm ``0`` is the baseline.

    The baseline term depends on the constructor arguments:

    * ``baseline_option == 'constant'``: the baseline is taken to be 0.
    * ``baseline_model`` given: that model's predictions are used as-is
      (``baseline_option`` is overwritten with ``'est'``).  The model is fed
      ``contexts[:, :-1]``, i.e. the contexts without their last column.
    * otherwise: a baseline model is learned from arm-0 observations, and
      ``'UCB'`` / ``'LCB'`` add / subtract a confidence width to its
      prediction; any other option uses the plain estimate.

    Until every learned model has seen both label values, the learner is in
    an initialisation phase in which it pulls a single arm (``to_pull``) for
    the whole budget.

    NOTE(review): feedback is assumed to be binary (0/1) and contexts are
    assumed to be a 2-D array of shape ``(n_individuals, vdim)`` -- confirm
    against the callers.
    """

    def __init__(self, n_treatments, vdim, baseline_model=None,
                 baseline_option='UCB', radius=2, reg=1,
                 uplift_threshold=None, explore_proba=0, explore_rng=42):
        """Set up the per-treatment models, the baseline and the bookkeeping.

        Args:
            n_treatments: number of treatment arms (arms ``1..n_treatments``).
            vdim: dimension of the context vectors.
            baseline_model: optional already-fitted model exposing
                ``predict_proba``; when given, the baseline is not learned.
            baseline_option: ``'UCB'``, ``'LCB'``, ``'constant'`` or anything
                else for a plain estimate (see the class docstring).
            radius: multiplier of the confidence widths.
            reg: ridge coefficient; Gram matrices start at ``reg * I``.
            uplift_threshold: if not ``None``, only individuals whose
                optimistic uplift is strictly above it may be treated.
            explore_proba: per-budget-unit probability of spending that unit
                on a uniformly random treatment instead.
            explore_rng: seed (or anything accepted by
                ``np.random.default_rng``) for the exploration randomness.
        """
        self.n_treatments = n_treatments
        self.radius = radius
        self.uplift_threshold = uplift_threshold
        # For treatment models
        self.models = [
            LogisticRegression(max_iter=1000, fit_intercept=False)
            for _ in range(n_treatments)
        ]
        self.cov_matrices = np.array([reg * np.eye(vdim) for _ in range(n_treatments)])
        self.Xs = [[] for _ in range(n_treatments)]
        self.ys = [[] for _ in range(n_treatments)]
        # For baseline model
        self.baseline_option = baseline_option
        if baseline_option == 'constant':
            # Baseline is fixed at 0 in `act`; nothing to learn.
            self.know_baseline = True
        elif baseline_model is None:
            # Learn the baseline from arm-0 observations.
            self.baseline_model = LogisticRegression(max_iter=1000, fit_intercept=False)
            self.baseline_cov_matrix = reg * np.eye(vdim)
            self.baseline_X = []
            self.baseline_y = []
            self.know_baseline = False
        else:
            # A ready-made baseline model: use its estimate without any bonus.
            self.know_baseline = True
            self.baseline_model = baseline_model
            self.baseline_option = 'est'
        # For epsilon-greedy
        self.explore_proba = explore_proba
        self.explore_rng = np.random.default_rng(explore_rng)
        # For initialization phase: arm to pull next, or -1 once it is over
        self.to_pull = 1
        # Store some information
        self.rewards, self.uplifts, self.arm_his = [], [], []
        self.rewards_from_data = []
        self.ns_explore = []

    @staticmethod
    def _single_class(labels):
        """Return True if `labels` is empty, all zeros or all ones."""
        total = np.sum(labels)
        return total == 0 or total == len(labels)

    def check_visit(self):
        """Return the next arm that still needs data, or -1 if none does.

        An arm needs data while its stored labels do not contain both values
        (a logistic model cannot be fitted on a single class).  Treatments
        are checked first, in order, and yield ``i + 1``; the baseline is
        checked last, only when it is learned, and yields ``0``.
        """
        for i in range(self.n_treatments):
            if self._single_class(self.ys[i]):
                return i + 1
        if not self.know_baseline and self._single_class(self.baseline_y):
            return 0
        return -1

    def update(self, arm, contexts, feedback, uplift, reward):
        """Record one round of observations and refit the models.

        Args:
            arm: array of the arm given to each individual (0 = baseline).
            contexts: array of contexts, one row per individual.
            feedback: array of observed outcomes, aligned with `arm`.
            uplift: value appended to ``self.uplifts``.
            reward: value appended to ``self.rewards_from_data``.

        All new observations (every treatment and, if learned, the baseline)
        are stored *before* deciding whether the initialisation phase is
        over.  Checking earlier, while only part of the round has been
        stored, could end the phase halfway through the treatments and leave
        the earlier models unfitted, which makes the next `act` fail.
        """
        arm = np.asarray(arm)
        contexts = np.asarray(contexts)
        feedback = np.asarray(feedback)

        for i in range(self.n_treatments):
            treated = arm == i + 1
            contexts_i = contexts[treated]
            self.cov_matrices[i] += contexts_i.T @ contexts_i
            self.Xs[i].extend(contexts_i)
            self.ys[i].extend(feedback[treated])
        if not self.know_baseline:
            untreated = arm == 0
            contexts_baseline = contexts[untreated]
            self.baseline_cov_matrix += contexts_baseline.T @ contexts_baseline
            self.baseline_X.extend(contexts_baseline)
            self.baseline_y.extend(feedback[untreated])

        # Re-evaluate the initialisation phase once, on the complete data.
        if self.to_pull >= 0:
            self.to_pull = self.check_visit()
        # Once every arm has both labels, (re)fit all models together so that
        # none of them is left unfitted when `act` is called next.
        if self.to_pull < 0:
            for model, X, y in zip(self.models, self.Xs, self.ys):
                model.fit(X, y)
            if not self.know_baseline:
                self.baseline_model.fit(self.baseline_X, self.baseline_y)

        self.rewards.append(np.sum(feedback))
        self.uplifts.append(uplift)
        self.rewards_from_data.append(reward)
        self.arm_his.append(arm)

    def act(self, contexts, budget, step=None):
        """Choose an arm for every individual of the round.

        Args:
            contexts: array of contexts, one row per individual.
            budget: maximum number of individuals that may be treated.
            step: optional round index; when given, the confidence radius is
                scaled by ``log(step + 1)`` (which is 0 for ``step == 0``).

        Returns:
            Integer array of length ``len(contexts)``; ``0`` means untreated,
            ``k >= 1`` means treatment ``k``.
        """
        n_individuals = len(contexts)
        arm = np.zeros(n_individuals, dtype=int)
        # Initialization: spend the whole budget on the arm that lacks data
        if self.to_pull >= 0:
            arm[:budget] = self.to_pull
            return arm

        contexts = np.asarray(contexts)
        # Cannot treat more individuals than there are; this also guarantees
        # that enough untreated individuals remain for the exploration draw.
        budget = min(budget, n_individuals)

        # Number of budget units that are spent on random treatments this round
        explore_budget = np.sum(self.explore_rng.random(budget) < self.explore_proba)
        self.ns_explore.append(explore_budget)
        budget -= explore_budget

        # Point estimates, one column per treatment.
        # Index with [1] (a list) to keep each prediction as a column.
        estimates = np.hstack([model.predict_proba(contexts)[:, [1]] for model in self.models])
        baseline = np.zeros(n_individuals)
        if self.baseline_option != 'constant':
            if self.know_baseline:
                # The provided model takes the contexts without the last column
                baseline = self.baseline_model.predict_proba(contexts[:, :-1])[:, 1]
            else:
                baseline = self.baseline_model.predict_proba(contexts)[:, 1]

        # Confidence widths: radius * sqrt(x^T V^{-1} x) for every individual
        # x and every Gram matrix V, computed for the whole batch at once.
        radius = self.radius if step is None else self.radius * np.log(step + 1)
        cov_matrices_inv = np.linalg.inv(self.cov_matrices)
        estimates = estimates + radius * np.sqrt(
            np.einsum('nd,kde,ne->nk', contexts, cov_matrices_inv, contexts))
        if self.baseline_option in ('UCB', 'LCB'):
            baseline_cov_inv = np.linalg.inv(self.baseline_cov_matrix)
            baseline_width = radius * np.sqrt(
                np.einsum('nd,de,ne->n', contexts, baseline_cov_inv, contexts))
            if self.baseline_option == 'UCB':
                baseline = baseline + baseline_width
            else:
                baseline = baseline - baseline_width

        opt_treatments = np.argmax(estimates, axis=1) + 1
        treat_UCBs = np.minimum(1, np.max(estimates, axis=1))
        # NOTE(review): only the upper end is clipped; with 'LCB' the baseline
        # can fall below 0 -- confirm whether a lower clip is wanted.
        baseline = np.minimum(1, baseline)
        treat_UPs = treat_UCBs - baseline

        # Deal with uplift_threshold (None by default, i.e. no threshold):
        # never treat more individuals than those strictly above it.
        if self.uplift_threshold is not None:
            budget = min(budget, np.sum(treat_UPs > self.uplift_threshold))
        if budget > 0:
            uplift_order = np.argsort(treat_UPs)
            selected = uplift_order[-budget:]
            unselected = uplift_order[:-budget]
            arm[selected] = opt_treatments[selected]
        else:
            # Nobody was selected: exploration may pick among everyone
            unselected = n_individuals

        # Epsilon greedy: random treatments for some not-yet-treated individuals
        if explore_budget > 0:
            explored = self.explore_rng.choice(unselected, explore_budget, replace=False)
            arm[explored] = self.explore_rng.integers(1, self.n_treatments+1, size=explore_budget)
        return arm
